{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DB Setup\n",
    "\n",
    "We assume you already have a SQLite database ready (`loan.db` by default; change `DATABASE` below to point at a different file)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import sqlite3"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Database setup (optional): uncomment the setup code below if the database doesn't exist yet"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the data from https://www.kaggle.com/datasets/ranadeep/credit-risk-dataset and place `loan.csv` at `data/loan.csv`, the path the setup code reads."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = pd.read_csv('data/loan.csv')\n",
    "# conn = sqlite3.connect('loan.db')\n",
    "# df.to_sql('loan', conn, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the SQLite database file to query.\n",
    "DATABASE = \"loan.db\"\n",
    "\n",
    "# List of tables to load, or [] to load all tables.\n",
    "TABLES = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading tables: ['loan']\n"
     ]
    }
   ],
   "source": [
    "from db_connectors import SQLiteConnector\n",
    "from prompt_formatters import RajkumarFormatter\n",
    "\n",
    "# Build the connector for the configured database and open it.\n",
    "sqlite_connector = SQLiteConnector(database_path=DATABASE)\n",
    "sqlite_connector.connect()\n",
    "\n",
    "# An empty TABLES list means: use every table the connector reports.\n",
    "if not TABLES:\n",
    "    TABLES.extend(sqlite_connector.get_tables())\n",
    "\n",
    "print(f\"Loading tables: {TABLES}\")\n",
    "\n",
    "# Fetch one schema per table and hand them to the prompt formatter.\n",
    "db_schema = list(map(sqlite_connector.get_schema, TABLES))\n",
    "formatter = RajkumarFormatter(db_schema)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Setup\n",
    "\n",
    "In a separate screen or window, first install [Manifest](https://github.com/HazyResearch/manifest)\n",
    "```bash\n",
    "pip install manifest-ml\\[all\\]\n",
    "```\n",
    "\n",
    "Then start the Manifest model server:\n",
    "```bash\n",
    "python3 -m manifest.api.app \\\n",
    "    --model_type huggingface \\\n",
    "    --model_generation_type text-generation \\\n",
    "    --model_name_or_path NumbersStation/nsql-6B \\\n",
    "    --device 0\n",
    "```\n",
    "\n",
    "If successful, you will see an output like\n",
    "```bash\n",
    "* Running on http://127.0.0.1:5000\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from manifest import Manifest\n",
    "\n",
    "manifest_client = Manifest(client_name=\"huggingface\", client_connection=\"http://127.0.0.1:5000\")\n",
    "\n",
    "def get_sql(instruction: str, max_tokens: int = 300) -> str:\n",
    "    \"\"\"Turn a natural-language instruction into SQL via the Manifest client.\n",
    "\n",
    "    Args:\n",
    "        instruction: The question or request to translate.\n",
    "        max_tokens: Forwarded to `manifest_client.run` as `max_tokens`.\n",
    "\n",
    "    Returns:\n",
    "        The model response after `formatter.format_model_output` is applied.\n",
    "    \"\"\"\n",
    "    model_prompt = formatter.format_prompt(instruction)\n",
    "    raw_completion = manifest_client.run(model_prompt, max_tokens=max_tokens)\n",
    "    return formatter.format_model_output(raw_completion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample instructions for the model; the comment above each entry is its category.\n",
    "questions = [\n",
    "    # illogical question\n",
    "    \"What is the member_id for the user who has the highest risk?\",\n",
    "    # basic\n",
    "    \"How many rows are in the loan database?\",\n",
    "    # basic\n",
    "    \"What is the breakdown of the grade column?\",\n",
    "    # basic longer\n",
    "    \"Count the number of rows where the loan status is any of the following ['Charged Off', 'Default', 'Does not meet the credit policy. Status:Charged Off', 'In Grace Period', 'Late (31-120 days)', 'Late (16-30 days)']\",\n",
    "    # limitation: alter table\n",
    "    \"Alter table to add a column titled 'class' that is 1 if loan_status is any of the following ['Charged Off', 'Default', 'Does not meet the credit policy. Status:Charged Off', 'In Grace Period', 'Late (31-120 days)', 'Late (16-30 days)'] and 0 otherwise.\",\n",
    "    # basic\n",
    "    \"How many loans where risky is 1 are there?\",\n",
    "    \"A loan is risky when risky =  1. Count the risky loans, grouped by term length.\",\n",
    "    # TODO: top ten for some filter\n",
    "    \"What is the maximum loan amount?\",\n",
    "    # TODO: split\n",
    "    \"Split the original table based on term length?\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SELECT term, COUNT(*) FROM loan GROUP BY term;\n"
     ]
    }
   ],
   "source": [
    "# Translate the selected question (index 8) into SQL and show the query.\n",
    "selected_question = questions[8]\n",
    "sql = get_sql(selected_question)\n",
    "print(sql)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "         term  COUNT(*)\n",
      "0   36 months    621125\n",
      "1   60 months    266254\n"
     ]
    }
   ],
   "source": [
    "# Execute the generated query against the database and print the result.\n",
    "query_result = sqlite_connector.run_sql_as_df(sql)\n",
    "print(query_result)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### TODO\n",
    "\n",
    "1. Research the architecture of NSQL.\n",
    "2. How did NSQL perform in benchmarks?\n",
    "3. How was NSQL fine-tuned?\n",
    "4. Migrate from SQLite to Db2 for better scaling."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
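    "#### Add riskiness to the database\n",
    "\n",
    "The commented-out cell below adds a `risky` column (1 when `loan_status` is one of the listed statuses, 0 otherwise) and rewrites the `loan` table. It reuses `df` and `conn` from the commented-out database setup cell, and the questions that mention `risky` depend on this column."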
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# risky_statuses = ['Charged Off', 'Default', 'Does not meet the credit policy. Status:Charged Off', 'In Grace Period',\n",
    "#                       'Late (31-120 days)', 'Late (16-30 days)']\n",
    "\n",
    "# df['risky'] = df.loan_status.apply(lambda x: int(x in risky_statuses))\n",
    "\n",
    "# df.to_sql('loan', conn, if_exists='replace', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dbt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
